iT邦幫忙

2023 iThome 鐵人賽

DAY 16
0

3. SFT訓練與Loss計算

sft訓練的程式碼sft.py

以下為簡化過的程式碼,X, Y, loss_mask的內容如何產生已經在前面介紹SFT dataset的時候介紹過了,所以SFT訓練部份主要的重點剩下Loss的計算:

  • ignore_index=0: 忽略掉padding_token_id對應的Loss
  • reduce=False: 這邊只是暫時不做reduce,然而後面乘完loss_mask以後還是會做reduce,這樣每一個訓練樣本對於模型的影響才會是一致的
  • loss = torch.sum(loss*loss_mask)/loss_mask.sum()
    • torch.sum(loss*loss_mask): 僅計算answer對應tokens的loss
    • torch.sum(...)/loss_mask.sum(): 最後對loss做mean reduce
def train_epoch(epoch):
    start_time=time.time()
    for step, (X, Y,loss_mask) in enumerate(train_loader):
        ......
        logits = model(X, Y)
        loss = F.cross_entropy(logits, Y, ignore_index=0,reduce=False)
        loss = torch.sum(loss*loss_mask)/loss_mask.sum()
        ......

訓練參數

if __name__=="__main__":
    max_epoch = 10
    batch_size = 32
    # model
    max_seq_len = 512
    dim = 512
    n_layers = 8
    n_heads = 8
    multiple_of = 32
    ......

上一篇
Day 15 - Baby LLama2 Chinese (9) SFT階段
下一篇
Day 17 - Baby LLama2 Chinese (11)
系列文
用單張顯卡探索大型語言模型的奧秘30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言